package io

import (
	"io"
	"runtime"
	"sync"
	"sync/atomic"

	"gitee.com/yrwy/msgo/pkg/errors"
	msync "gitee.com/yrwy/msgo/pkg/sync"
)

// RingStream is a fixed-capacity circular byte buffer shared between writers
// (Write/TryWrite) and readers (Read/ReadBytes/TryReadBytes). Writers are
// serialized by wm and readers by rm; the positions and the length are
// accessed with sync/atomic so a reader and a writer can proceed concurrently.
// The embedded msync.Closer presumably supplies the Done channel used to
// detect that the stream was closed.
type RingStream struct {
	msync.Closer

	// wm serializes writers.
	wm sync.Mutex
	// rm serializes readers.
	rm sync.Mutex

	// buf is the backing storage of the ring.
	buf []byte

	// wpos is the index in buf of the next byte to write.
	wpos int32
	// rpos is the index in buf of the next byte to read.
	rpos int32
	// size is the number of unread bytes currently held in buf.
	size int32

	// rch notifies readers that data may be available.
	rch chan int8
	// wch notifies writers that space may be available.
	wch chan int8
}

// NewRingStream creates a RingStream whose ring holds size bytes.
//
// Sizes below 1 MiB are raised to 1 MiB. Read/write positions are stored as
// int32, so sizes above math.MaxInt32 are clamped to that limit; a larger
// buffer would make the positions silently overflow.
//
// The notification channels have a one-slot buffer: TryWrite and
// TryReadBytes notify with a non-blocking send, and with unbuffered channels
// a notification sent while the peer had not yet reached its select was
// dropped, which could leave a reader blocked even though data was buffered.
// With a slot, the pending notification is kept until the peer receives it.
func NewRingStream(size int) *RingStream {
	const (
		minSize = 1024 * 1024
		maxSize = 1<<31 - 1 // largest value an int32 position can address
	)
	if size < minSize {
		size = minSize
	}
	if size > maxSize {
		size = maxSize
	}
	b := &RingStream{
		Closer: msync.NewCloser(),
		buf:    make([]byte, size),
		rch:    make(chan int8, 1),
		wch:    make(chan int8, 1),
	}
	// NOTE(review): the finalizer only closes the notification channels; it
	// releases no other resources.
	runtime.SetFinalizer(b, func(b *RingStream) {
		close(b.rch)
		close(b.wch)
	})
	return b
}

// Len reports the number of unread bytes currently buffered.
func (b *RingStream) Len() int {
	buffered := atomic.LoadInt32(&b.size)
	return int(buffered)
}

// Write copies all of p into the stream, blocking while the ring is full.
// It returns the number of bytes written, together with errors.ErrClosed if
// the stream is closed before p has been written completely.
func (b *RingStream) Write(p []byte) (int, error) {
	written, err := b.TryWrite(p)
	for err == nil && written < len(p) {
		select {
		case <-b.Done():
			return written, errors.ErrClosed
		case <-b.wch:
			// A reader signalled that space may have been freed: retry.
			var n int
			n, err = b.TryWrite(p[written:])
			written += n
		case b.rch <- 1:
			// Nudge a reader that may be waiting for data.
		}
	}
	return written, err
}

// TryWrite copies as much of p into the ring as currently fits, without
// blocking, and returns the number of bytes copied. If the stream is already
// closed it returns 0 and errors.ErrClosed. After every attempt it makes a
// non-blocking send on rch to wake a reader that may be waiting for data.
//
// Free space is computed from the shared size counter alone. wpos is only
// modified here, under wm, and a concurrent reader can only increase the free
// space, so a slightly stale size is merely conservative. Choosing the
// writable region by comparing wpos with a separately loaded rpos (as an
// earlier version did) could race with a reader and copy past the free
// region, overwriting unread bytes.
func (b *RingStream) TryWrite(p []byte) (int, error) {
	select {
	case <-b.Done():
		return 0, errors.ErrClosed
	default:
	}
	if len(p) == 0 {
		return 0, nil
	}

	b.wm.Lock()
	defer b.wm.Unlock()

	capacity := len(b.buf)
	free := capacity - int(atomic.LoadInt32(&b.size))
	n := len(p)
	if n > free {
		n = free
	}
	nr := 0
	if n > 0 {
		wpos := int(atomic.LoadInt32(&b.wpos))
		// Fill [wpos, capacity) first, then wrap to the start of buf.
		nr = copy(b.buf[wpos:], p[:n])
		nr += copy(b.buf, p[nr:n])
		// The bytes are copied before size is increased, so a reader that
		// observes the new size also observes the data.
		atomic.StoreInt32(&b.wpos, int32((wpos+nr)%capacity))
		atomic.AddInt32(&b.size, int32(nr))
	}

	// Signal that data may be readable.
	select {
	case b.rch <- 1:
	default:
	}
	return nr, nil
}

// Read implements io.Reader on top of ReadBytes: it blocks until data is
// available and copies at most len(buf) bytes into buf, returning whatever
// error ReadBytes reports (io.EOF once the stream is closed).
func (b *RingStream) Read(buf []byte) (int, error) {
	if len(buf) == 0 {
		return 0, nil
	}
	rest := buf
	return b.ReadBytes(len(buf), func(chunk []byte) {
		rest = rest[copy(rest, chunk):]
	})
}

// ReadBytes blocks until buffered data is available (or the stream is
// closed), then hands up to n bytes to cb and returns how many bytes were
// consumed. If n <= 0, all currently buffered bytes are consumed.
//
// io.EOF is returned only after the stream is closed and a final read attempt
// finds no buffered data. Previously, a reader already waiting in select could
// observe Done first and return io.EOF while bytes written just before Close
// were still in the buffer.
func (b *RingStream) ReadBytes(n int, cb func([]byte)) (int, error) {
	for {
		if nr := b.TryReadBytes(n, cb); nr > 0 {
			return nr, nil
		}
		select {
		case <-b.Done():
			// Drain anything written before the close before reporting EOF.
			if nr := b.TryReadBytes(n, cb); nr > 0 {
				return nr, nil
			}
			return 0, io.EOF
		case <-b.rch:
			// A writer signalled that data may be available: retry.
		case b.wch <- 2:
			// Let a writer waiting for space know it may retry.
		}
	}
}

// TryReadBytes hands up to n buffered bytes to cb without blocking and
// returns how many bytes were consumed; if n <= 0, all currently buffered
// bytes are consumed.
//
// Because the data may wrap around the end of the ring, cb can be invoked
// twice (first the segment ending at the end of buf, then the segment at its
// start). It is not invoked when nothing is read. The slices passed to cb
// alias the internal buffer, and a writer may overwrite them once this call
// returns, so cb must copy anything it wants to keep.
//
// After every attempt a non-blocking send on wch wakes a writer that may be
// waiting for space.
//
// The readable amount is derived from the shared size counter alone. rpos is
// only modified here, under rm, and a concurrent writer can only increase
// size, so a slightly stale value is merely conservative. Comparing rpos with
// a separately loaded wpos (as an earlier version did) could race with a
// writer and hand out bytes that had not been written yet.
func (b *RingStream) TryReadBytes(n int, cb func([]byte)) int {
	b.rm.Lock()
	defer b.rm.Unlock()

	avail := int(atomic.LoadInt32(&b.size))
	if n <= 0 || n > avail {
		n = avail
	}
	nr := 0
	if n > 0 {
		capacity := len(b.buf)
		rpos := int(atomic.LoadInt32(&b.rpos))
		first := capacity - rpos
		if first > n {
			first = n
		}
		cb(b.buf[rpos : rpos+first])
		if second := n - first; second > 0 {
			cb(b.buf[:second])
		}
		nr = n
		// cb has finished with the bytes before size shrinks, so a writer
		// only reuses this region after it has been consumed.
		atomic.StoreInt32(&b.rpos, int32((rpos+nr)%capacity))
		atomic.AddInt32(&b.size, -int32(nr))
	}

	// Signal that space may be writable.
	select {
	case b.wch <- 2:
	default:
	}
	return nr
}
